home *** CD-ROM | disk | FTP | other *** search
/ Clickx 115 / Clickx 115.iso / software / tools / windows / tails-i386-0.16.iso / live / filesystem.squashfs / usr / share / pyshared / gnutls / connection.py < prev    next >
Encoding:
Python Source  |  2008-04-10  |  19.3 KB  |  501 lines

  1. # Copyright (C) 2007-2008 AG Projects. See LICENSE for details.
  2. #
  3.  
  4. """GNUTLS connection support"""
  5.  
  6. __all__ = ['X509Credentials', 'ClientSession', 'ServerSession', 'ServerSessionFactory']
  7.  
  8. from time import time
  9. from socket import SHUT_RDWR as SOCKET_SHUT_RDWR
  10.  
  11. from _ctypes import PyObj_FromPtr
  12. from ctypes import *
  13.  
  14. from gnutls.validators import *
  15. from gnutls.constants import *
  16. from gnutls.crypto import *
  17. from gnutls.errors import *
  18.  
  19. from gnutls.library.constants import GNUTLS_SERVER, GNUTLS_CLIENT, GNUTLS_CRT_X509
  20. from gnutls.library.constants import GNUTLS_CERT_INVALID, GNUTLS_CERT_REVOKED, GNUTLS_CERT_INSECURE_ALGORITHM
  21. from gnutls.library.constants import GNUTLS_CERT_SIGNER_NOT_FOUND, GNUTLS_CERT_SIGNER_NOT_CA
  22. from gnutls.library.constants import GNUTLS_AL_FATAL, GNUTLS_A_BAD_CERTIFICATE
  23. from gnutls.library.constants import GNUTLS_A_UNKNOWN_CA, GNUTLS_A_INSUFFICIENT_SECURITY
  24. from gnutls.library.constants import GNUTLS_A_CERTIFICATE_EXPIRED, GNUTLS_A_CERTIFICATE_REVOKED
  25. from gnutls.library.constants import GNUTLS_NAME_DNS
  26. from gnutls.library.types     import gnutls_certificate_credentials_t, gnutls_session_t, gnutls_x509_crt_t
  27. from gnutls.library.types     import gnutls_certificate_server_retrieve_function
  28. from gnutls.library.functions import *
  29.  
  30.  
  31. @gnutls_certificate_server_retrieve_function
  32. def _retrieve_server_certificate(c_session, retr_st):
  33.     session = PyObj_FromPtr(gnutls_session_get_ptr(c_session))
  34.     identity = session.credentials.select_server_identity(session)
  35.     retr_st.contents.type = GNUTLS_CRT_X509
  36.     retr_st.contents.deinit_all = 0
  37.     if identity is None:
  38.         retr_st.contents.ncerts = 0
  39.     else:
  40.         retr_st.contents.ncerts = 1
  41.         retr_st.contents.cert.x509.contents = identity.cert._c_object
  42.         retr_st.contents.key.x509 = identity.key._c_object
  43.     return 0
  44.  
  45.  
  46. class _ServerNameIdentities(dict):
  47.     """Used internally by X509Credentials to map server names to X509 identities for the server name extension"""
  48.     def __init__(self, identities):
  49.         dict.__init__(self)
  50.         for identity in identities:
  51.             self.add(identity)
  52.     def add(self, identity):
  53.         for name in identity.cert.alternative_names.dns:
  54.             self[name.lower()] = identity
  55.         for ip in identity.cert.alternative_names.ip:
  56.             self[ip] = identity
  57.         subject = identity.cert.subject
  58.         if subject.CN is not None:
  59.             self[subject.CN.lower()] = identity
  60.     def get(self, server_name, default=None):
  61.         server_name = server_name.lower()
  62.         if server_name in self:
  63.             return self[server_name]
  64.         for name in (n for n in self if n.startswith('*.')):
  65.             suffix = name[1:]
  66.             if server_name.endswith(suffix) and '.' not in server_name[:-len(suffix)]:
  67.                 return self[name]
  68.         return default
  69.  
  70.  
  71. class X509Credentials(object):
  72.     DH_BITS  = 1024
  73.     RSA_BITS = 1024
  74.  
  75.     dh_params  = None
  76.     rsa_params = None
  77.  
  78.     def __new__(cls, *args, **kwargs):
  79.         c_object = gnutls_certificate_credentials_t()
  80.         gnutls_certificate_allocate_credentials(byref(c_object))
  81.         instance = object.__new__(cls)
  82.         instance.__deinit = gnutls_certificate_free_credentials
  83.         instance._c_object = c_object
  84.         return instance
  85.  
  86.     @method_args((X509Certificate, none), (X509PrivateKey, none), list_of(X509Certificate), list_of(X509CRL), list_of(X509Identity))
  87.     def __init__(self, cert=None, key=None, trusted=[], crl_list=[], identities=[]):
  88.         """Credentials contain a X509 certificate, a private key, a list of trusted CAs and a list of CRLs (all optional).
  89.         An optional list of additional X509 identities can be specified for applications that need more that one identity"""
  90.         if cert and key:
  91.             gnutls_certificate_set_x509_key(self._c_object, byref(cert._c_object), 1, key._c_object)
  92.         elif (cert, key) != (None, None):
  93.             raise ValueError("Specify neither or both the certificate and private key")
  94.         gnutls_certificate_server_set_retrieve_function(self._c_object, _retrieve_server_certificate)
  95.         self._max_depth = 5
  96.         self._max_bits  = 8200
  97.         self._type = CRED_CERTIFICATE
  98.         self._cert = cert
  99.         self._key = key
  100.         self._identities = tuple(identities)
  101.         self._trusted = ()
  102.         self.add_trusted(trusted)
  103.         self.crl_list = crl_list
  104.         self.server_name_identities = _ServerNameIdentities(identities)
  105.         if cert and key:
  106.             self.server_name_identities.add(X509Identity(cert, key))
  107.         self.session_params = SessionParams(self._type)
  108.  
  109.     def __del__(self):
  110.         self.__deinit(self._c_object)
  111.  
  112.     # Methods to alter the credentials at runtime
  113.  
  114.     @method_args(list_of(X509Certificate))
  115.     def add_trusted(self, trusted):
  116.         size = len(trusted)
  117.         if size > 0:
  118.             ca_list = (gnutls_x509_crt_t * size)(*[cert._c_object for cert in trusted])
  119.             gnutls_certificate_set_x509_trust(self._c_object, cast(byref(ca_list), POINTER(gnutls_x509_crt_t)), size)
  120.             self._trusted = self._trusted + tuple(trusted)
  121.  
  122.     def generate_dh_params(self, bits=DH_BITS):
  123.         reference = self.dh_params ## keep a reference to preserve it until replaced
  124.         X509Credentials.dh_params  = DHParams(bits)
  125.         del reference
  126.  
  127.     def generate_rsa_params(self, bits=RSA_BITS):
  128.         reference = self.rsa_params ## keep a reference to preserve it until replaced
  129.         X509Credentials.rsa_params = RSAParams(bits)
  130.         del reference
  131.  
  132.     # Properties
  133.  
  134.     @property
  135.     def cert(self):
  136.         return self._cert
  137.  
  138.     @property
  139.     def key(self):
  140.         return self._key
  141.  
  142.     @property
  143.     def identities(self):
  144.         return self._identities
  145.  
  146.     @property
  147.     def trusted(self):
  148.         return self._trusted
  149.  
  150.     def _get_crl_list(self):
  151.         return self._crl_list
  152.     @method_args(list_of(X509CRL)) 
  153.     def _set_crl_list(self, crl_list):
  154.         self._crl_list = tuple(crl_list)
  155.     crl_list = property(_get_crl_list, _set_crl_list)
  156.     del _get_crl_list, _set_crl_list
  157.  
  158.     def _get_max_verify_length(self):
  159.         return self._max_depth
  160.     @method_args(int) 
  161.     def _set_max_verify_length(self, max_depth):
  162.         gnutls_certificate_set_verify_limits(self._c_object, self._max_bits, max_depth)
  163.         self._max_depth = max_depth
  164.     max_verify_length = property(_get_max_verify_length, _set_max_verify_length)
  165.     del _get_max_verify_length, _set_max_verify_length
  166.  
  167.     def _get_max_verify_bits(self):
  168.         return self._max_bits
  169.     @method_args(int) 
  170.     def _set_max_verify_bits(self, max_bits):
  171.         gnutls_certificate_set_verify_limits(self._c_object, max_bits, self._max_depth)
  172.         self._max_bits = max_bits
  173.     max_verify_bits = property(_get_max_verify_bits, _set_max_verify_bits)
  174.     del _get_max_verify_bits, _set_max_verify_bits
  175.  
  176.     # Methods to select and validate certificates
  177.  
  178.     def check_certificate(self, cert, cert_name='certificate'):
  179.         """Verify activation, expiration and revocation for the given certificate"""
  180.         now = time()
  181.         if cert.activation_time > now:
  182.             raise CertificateExpiredError("%s is not yet activated" % cert_name)
  183.         if cert.expiration_time < now:
  184.             raise CertificateExpiredError("%s has expired" % cert_name)
  185.         for crl in self.crl_list:
  186.             crl.check_revocation(cert, cert_name=cert_name)
  187.  
  188.     def select_server_identity(self, session):
  189.         """Select which identity the server will use for a given session. The default selection algorithm uses
  190.         the server name extension. A subclass can overwrite it if a different selection algorithm is desired."""
  191.         server_name = session.server_name
  192.         if server_name is not None:
  193.             return self.server_name_identities.get(server_name)
  194.         elif self.cert and self.key:
  195.             return self ## since we have the cert and key attributes we can behave like a X509Identity
  196.         else:
  197.             return None
  198.  
  199.  
  200. class SessionParams(object):
  201.     _default_kx_algorithms = {
  202.         CRED_CERTIFICATE: (KX_RSA, KX_DHE_DSS, KX_DHE_RSA),
  203.         CRED_ANON: (KX_ANON_DH,)}
  204.     _all_kx_algorithms = {
  205.         CRED_CERTIFICATE: set((KX_RSA, KX_DHE_DSS, KX_DHE_RSA, KX_RSA_EXPORT)),
  206.         CRED_ANON: set((KX_ANON_DH,))}
  207.  
  208.     def __new__(cls, credentials_type):
  209.         if credentials_type not in cls._default_kx_algorithms:
  210.             raise TypeError("Unknown credentials type: %r" % credentials_type)
  211.         return object.__new__(cls)
  212.  
  213.     def __init__(self, credentials_type):
  214.         self._credentials_type = credentials_type
  215.         self._protocols = (PROTO_TLS1_1, PROTO_TLS1_0, PROTO_SSL3)
  216.         self._kx_algorithms = self._default_kx_algorithms[credentials_type]
  217.         self._ciphers = (CIPHER_AES_128_CBC, CIPHER_3DES_CBC, CIPHER_ARCFOUR_128)
  218.         self._mac_algorithms = (MAC_SHA1, MAC_MD5, MAC_RMD160)
  219.         self._compressions = (COMP_NULL,)
  220.  
  221.     def _get_protocols(self):
  222.         return self._protocols
  223.     def _set_protocols(self, protocols):
  224.         self._protocols = ProtocolListValidator(protocols)
  225.     protocols = property(_get_protocols, _set_protocols)
  226.     del _get_protocols, _set_protocols
  227.  
  228.     def _get_kx_algorithms(self):
  229.         return self._kx_algorithms
  230.     def _set_kx_algorithms(self, algorithms):
  231.         cred_type = self._credentials_type
  232.         algorithms = KeyExchangeListValidator(algorithms)
  233.         invalid = set(algorithms) - self._all_kx_algorithms[cred_type]
  234.         if invalid:
  235.             raise ValueError("Cannot specify %r with %r credentials" % (list(invalid), cred_type))
  236.         self._kx_algorithms = algorithms
  237.     kx_algorithms = property(_get_kx_algorithms, _set_kx_algorithms)
  238.     del _get_kx_algorithms, _set_kx_algorithms
  239.  
  240.     def _get_ciphers(self):
  241.         return self._ciphers
  242.     def _set_ciphers(self, ciphers):
  243.         self._ciphers = CipherListValidator(ciphers)
  244.     ciphers = property(_get_ciphers, _set_ciphers)
  245.     del _get_ciphers, _set_ciphers
  246.  
  247.     def _get_mac_algorithms(self):
  248.         return self._mac_algorithms
  249.     def _set_mac_algorithms(self, algorithms):
  250.         self._mac_algorithms = MACListValidator(algorithms)
  251.     mac_algorithms = property(_get_mac_algorithms, _set_mac_algorithms)
  252.     del _get_mac_algorithms, _set_mac_algorithms
  253.  
  254.     def _get_compressions(self):
  255.         return self._compressions
  256.     def _set_compressions(self, compressions):
  257.         self._compressions = CompressionListValidator(compressions)
  258.     compressions = property(_get_compressions, _set_compressions)
  259.     del _get_compressions, _set_compressions
  260.  
  261.  
  262. class Session(object):
  263.     """Abstract class representing a TLS session created from a TCP socket
  264.        and a Credentials object."""
  265.  
  266.     session_type = None ## placeholder for GNUTLS_SERVER or GNUTLS_CLIENT as defined by subclass
  267.  
  268.     def __new__(cls, *args, **kwargs):
  269.         if cls is Session:
  270.             raise RuntimeError("Session cannot be instantiated directly")
  271.         instance = object.__new__(cls)
  272.         instance.__deinit = gnutls_deinit
  273.         instance._c_object = gnutls_session_t()
  274.         return instance
  275.  
  276.     def __init__(self, socket, credentials):
  277.         gnutls_init(byref(self._c_object), self.session_type)
  278.         ## Store a pointer to self on the C session
  279.         gnutls_session_set_ptr(self._c_object, id(self))
  280.         # gnutls_dh_set_prime_bits(session, DH_BITS)?
  281.         gnutls_transport_set_ptr(self._c_object, socket.fileno())
  282.         gnutls_handshake_set_private_extensions(self._c_object, 1)
  283.         self.socket = socket
  284.         self.credentials = credentials
  285.         self._update_params()
  286.  
  287.     def __del__(self):
  288.         self.__deinit(self._c_object)
  289.  
  290.     def __getattr__(self, name):
  291.         ## Generic wrapper for the underlying socket methods and attributes.
  292.         return getattr(self.socket, name)
  293.  
  294.     # Session properties
  295.  
  296.     def _get_credentials(self):
  297.         return self._credentials
  298.     @method_args(X509Credentials)
  299.     def _set_credentials(self, credentials):
  300.         ## Release all credentials, otherwise gnutls will only release an existing credential of
  301.         ## the same type as the one being set and we can end up with multiple credentials in C.
  302.         gnutls_credentials_clear(self._c_object)
  303.         gnutls_credentials_set(self._c_object, credentials._type, cast(credentials._c_object, c_void_p))
  304.         self._credentials = credentials
  305.     credentials = property(_get_credentials, _set_credentials)
  306.     del _get_credentials, _set_credentials
  307.  
  308.     @property
  309.     def protocol(self):
  310.         return gnutls_protocol_get_name(gnutls_protocol_get_version(self._c_object))
  311.  
  312.     @property
  313.     def kx_algorithm(self):
  314.         return gnutls_kx_get_name(gnutls_kx_get(self._c_object))
  315.  
  316.     @property
  317.     def cipher(self):
  318.         return gnutls_cipher_get_name(gnutls_cipher_get(self._c_object))
  319.  
  320.     @property
  321.     def mac_algorithm(self):
  322.         return gnutls_mac_get_name(gnutls_mac_get(self._c_object))
  323.  
  324.     @property
  325.     def compression(self):
  326.         return gnutls_compression_get_name(gnutls_compression_get(self._c_object))
  327.  
  328.     @property
  329.     def peer_certificate(self):
  330.         if gnutls_certificate_type_get(self._c_object) != GNUTLS_CRT_X509:
  331.             return None
  332.         list_size = c_uint()
  333.         cert_list = gnutls_certificate_get_peers(self._c_object, byref(list_size))
  334.         if list_size.value == 0:
  335.             return None
  336.         cert = cert_list[0]
  337.         return X509Certificate(string_at(cert.data, cert.size), X509_FMT_DER)
  338.  
  339.     # Status checking after an operation was interrupted (these properties are
  340.     # only useful to check after an operation was interrupted, otherwise their
  341.     # value is meaningless).
  342.  
  343.     @property
  344.     def interrupted_while_writing(self):
  345.         """True if an operation was interrupted while writing"""
  346.         return gnutls_record_get_direction(self._c_object)==1
  347.  
  348.     @property
  349.     def interrupted_while_reading(self):
  350.         """True if an operation was interrupted while reading"""
  351.         return gnutls_record_get_direction(self._c_object)==0
  352.  
  353.     # Session methods
  354.  
  355.     def _update_params(self):
  356.         """Update the priorities of the session params using the credentials."""
  357.         def c_priority_list(priorities):
  358.             size = len(priorities) + 1
  359.             return (c_int * size)(*priorities)
  360.         session_params = self.credentials.session_params
  361.         # protocol order in the priority list is irrelevant (it always uses newer protocols first)
  362.         # the protocol list only specifies what protocols are to be enabled.
  363.         gnutls_protocol_set_priority(self._c_object, c_priority_list(session_params.protocols))
  364.         gnutls_kx_set_priority(self._c_object, c_priority_list(session_params.kx_algorithms))
  365.         gnutls_cipher_set_priority(self._c_object, c_priority_list(session_params.ciphers))
  366.         gnutls_mac_set_priority(self._c_object, c_priority_list(session_params.mac_algorithms))
  367.         gnutls_compression_set_priority(self._c_object, c_priority_list(session_params.compressions))
  368.  
  369.     def handshake(self):
  370.         gnutls_handshake(self._c_object)
  371.  
  372.     #@method_args((basestring, buffer))
  373.     def send(self, data):
  374.         data = str(data)
  375.         return gnutls_record_send(self._c_object, data, len(data))
  376.  
  377.     def sendall(self, data):
  378.         size = len(data)
  379.         while size > 0:
  380.             sent = self.send(data[-size:])
  381.             size -= sent
  382.  
  383.     def recv(self, limit):
  384.         data = create_string_buffer(limit)
  385.         size = gnutls_record_recv(self._c_object, data, limit)
  386.         return data[:size]
  387.  
  388.     def send_alert(self, exception):
  389.         alertdict = {
  390.             CertificateError: GNUTLS_A_BAD_CERTIFICATE,
  391.             CertificateAuthorityError: GNUTLS_A_UNKNOWN_CA,
  392.             CertificateSecurityError: GNUTLS_A_INSUFFICIENT_SECURITY,
  393.             CertificateExpiredError: GNUTLS_A_CERTIFICATE_EXPIRED,
  394.             CertificateRevokedError: GNUTLS_A_CERTIFICATE_REVOKED}
  395.         alert = alertdict.get(exception.__class__)
  396.         if alert:
  397.             gnutls_alert_send(self._c_object, GNUTLS_AL_FATAL, alert)
  398.  
  399.     @method_args(one_of(SHUT_RDWR, SHUT_WR))
  400.     def bye(self, how=SHUT_RDWR):
  401.         gnutls_bye(self._c_object, how)
  402.  
  403.     def shutdown(self, how=SOCKET_SHUT_RDWR):
  404.         self.socket.shutdown(how)
  405.  
  406.     def close(self):
  407.         self.socket.close()
  408.  
  409.     def verify_peer(self):
  410.         status = c_uint()
  411.         gnutls_certificate_verify_peers2(self._c_object, byref(status))
  412.         status = status.value
  413.         if status & GNUTLS_CERT_INVALID:
  414.             raise CertificateError("peer certificate is invalid")
  415.         elif status & GNUTLS_CERT_SIGNER_NOT_FOUND:
  416.             raise CertificateAuthorityError("peer certificate signer not found")
  417.         elif status & GNUTLS_CERT_SIGNER_NOT_CA:
  418.             raise CertificateAuthorityError("peer certificate signer is not a CA")
  419.         elif status & GNUTLS_CERT_INSECURE_ALGORITHM:
  420.             raise CertificateSecurityError("peer certificate uses an insecure algorithm")
  421.         elif status & GNUTLS_CERT_REVOKED:
  422.             raise CertificateRevokedError("peer certificate was revoked")
  423.  
  424.  
  425. class ClientSession(Session):
  426.     session_type = GNUTLS_CLIENT
  427.  
  428.     def __init__(self, socket, credentials, server_name=None):
  429.         Session.__init__(self, socket, credentials)
  430.         self._server_name = None
  431.         if server_name is not None:
  432.             self.server_name = server_name
  433.  
  434.     def _get_server_name(self):
  435.         return self._server_name
  436.     @method_args(str)
  437.     def _set_server_name(self, server_name):
  438.         gnutls_server_name_set(self._c_object, GNUTLS_NAME_DNS, c_char_p(server_name), len(server_name))
  439.         self._server_name = server_name
  440.     server_name = property(_get_server_name, _set_server_name)
  441.     del _get_server_name, _set_server_name
  442.  
  443.  
  444. class ServerSession(Session):
  445.     session_type = GNUTLS_SERVER
  446.  
  447.     def __init__(self, socket, credentials):
  448.         Session.__init__(self, socket, credentials)
  449.         gnutls_certificate_server_set_request(self._c_object, CERT_REQUEST)
  450.  
  451.     @property
  452.     def server_name(self):
  453.         data_length = c_size_t(256)
  454.         data = create_string_buffer(data_length.value)
  455.         hostname_type = c_uint()
  456.         for i in xrange(2**16):
  457.             try:
  458.                 gnutls_server_name_get(self._c_object, data, byref(data_length), byref(hostname_type), i)
  459.             except RequestedDataNotAvailable:
  460.                 break
  461.             except MemoryError:
  462.                 data_length.value += 1 ## one extra byte for the terminating 0
  463.                 data = create_string_buffer(data_length.value)
  464.                 gnutls_server_name_get(self._c_object, data, byref(data_length), byref(hostname_type), i)
  465.             if hostname_type.value != GNUTLS_NAME_DNS:
  466.                 continue
  467.             return data.value
  468.         return None
  469.  
  470.  
  471. class ServerSessionFactory(object):
  472.  
  473.     def __init__(self, socket, credentials, session_class=ServerSession):
  474.         if not issubclass(session_class, ServerSession):
  475.             raise TypeError, "session_class must be a subclass of ServerSession"
  476.         self.socket = socket
  477.         self.credentials = credentials
  478.         self.session_class = session_class
  479.  
  480.     def __getattr__(self, name):
  481.         ## Generic wrapper for the underlying socket methods and attributes
  482.         return getattr(self.socket, name)
  483.  
  484.     def bind(self, address):
  485.         self.socket.bind(address)
  486.  
  487.     def listen(self, backlog):
  488.         self.socket.listen(backlog)
  489.  
  490.     def accept(self):
  491.         new_sock, address = self.socket.accept()
  492.         session = self.session_class(new_sock, self.credentials)
  493.         return (session, address)
  494.  
  495.     def shutdown(self, how=SOCKET_SHUT_RDWR):
  496.         self.socket.shutdown(how)
  497.  
  498.     def close(self):
  499.         self.socket.close()
  500.  
  501.